// -*- mode: c++ -*-
/// \file intersect.hh
/// \brief Intersection computations.

#ifndef CGMATH_INTERSECT_HH
#define CGMATH_INTERSECT_HH

namespace cgmath
{
  /// \brief Check if a/b >= 0 without doing the division.
  ///
  /// Used by the intersection functions to test the sign of a ratio of
  /// determinants before (or instead of) actually dividing.
  ///
  /// The signs are compared directly rather than by testing a * b >= 0:
  /// the product overflows for large integers (undefined behaviour) and
  /// underflows to -0 for tiny floating point values of opposite sign,
  /// which would wrongly be reported as "same sign".
  ///
  /// \param a  Numerator.
  /// \param b  Denominator.
  /// \return   true if \p a and \p b are both non-negative or both
  ///           non-positive. A zero in either argument therefore counts as
  ///           "same sign". false if the signs differ or an argument is NaN.
  ///
  /// \todo Find out the fastest way to do this. I guess bit fiddling for
  /// ints, floats in SSE regs? Probably branching for doubles in FPU.
  ///
  template<typename scalar>
  bool same_sign (const scalar& a, const scalar& b)
  {
    return (a >= 0 && b >= 0) || (a <= 0 && b <= 0);
  }

  /// Intersection between ray and hyperplane.
  template<typename scalar, unsigned n>
  bool intersect (const ray<scalar, n>& r,
		  const hyperplane<scalar,n>& p, 
		  double& t)
  {
    scalar dn = dot(r.d, p.n);
    if (abs(dn) >= 0.0)
      {
	scalar pn = dot(r.p, p.n);
	double _t = -(p.d + pn) / dn;
	if (_t >= 0.0)
	  {
	    t = _t;
	    return true;
	  }
      }
    return false;
  }

  /// Intersection between line segment and triangle.
  template<typename scalar>
  bool intersect (const line_segment<scalar,3>& l,
		  const triangle<scalar,3>&     T,
		  double& t)
  {
    // TODO: reuse parts of determinants, compiler won't do it probably.
    vector_n<scalar,3> v = l.a - l.b;
    scalar ABv = det(T.a, T.b, v);
    scalar BCv = det(T.b, T.c, v);
    scalar ACv = det(T.a, T.c, v);
    scalar ABCv = BCv - ACv + ABv;
    if (fabs(ABCv) < std::numeric_limits<scalar>::epsilon())
      {
	// TODO: handle coplanar case by projection
	return false;
      }
    scalar PBv  = det(l.a, T.b, v);
    scalar PCv  = det(l.a, T.c, v);
    scalar PBCv = BCv - PCv + PBv;
    if (not same_sign (PBCv, ABCv)) return false; // a negative
    scalar APv  = det(T.a, l.a, v);
    scalar APCv = PCv - ACv + APv;
    if (not same_sign (APCv, ABCv)) return false; // b negative
    if (not same_sign (ABCv - PBCv - APCv, ABCv)) return false; // c negative
    // This part is probably too slow
    scalar BCP  = det(T.b, T.c, l.a);
    scalar ACP  = det(T.a, T.c, l.a);
    scalar ABP  = det(T.a, T.b, l.a);
    scalar ABC  = det(T.a, T.b, T.c);
    scalar ABCP = BCP - ACP + ABP - ABC;
    if (not same_sign (ABCP, ABCv)) return false; // outside start
    if (fabs(ABCP) > fabs(ABCv))    return false; // outside end
    t = static_cast<double>(ABCP)/static_cast<double>(ABCv);
    return true;
  }

  /// Intersection between line segment and triangle.
  /// Outputs also barycentric coordinates of hit point.
  template<typename scalar>
  bool intersect (const line_segment<scalar,3>& l,
		  const triangle<scalar,3>&     T,
		  double& t,
		  double& a,
		  double& b,
		  double& c)
  {
    vector_n<scalar,3> v = l.a - l.b;
    scalar ABv = det(T.a, T.b, v);
    scalar BCv = det(T.b, T.c, v);
    scalar ACv = det(T.a, T.c, v);
    scalar ABCv = BCv - ACv + ABv;
    if (fabs(ABCv) < std::numeric_limits<scalar>::epsilon())
      {
	// TODO: handle coplanar case by projection
	return false;
      }
    scalar PBv  = det(l.a, T.b, v);
    scalar PCv  = det(l.a, T.c, v);
    scalar PBCv = BCv - PCv + PBv;
    if (not same_sign (PBCv, ABCv)) return false; // a negative
    scalar APv  = det(T.a, l.a, v);
    scalar APCv = PCv - ACv + APv;
    if (not same_sign (APCv, ABCv)) return false; // b negative
    if (not same_sign (ABCv - PBCv - APCv, ABCv)) return false; // c negative
    // This part is probably too slow
    scalar BCP  = det(T.b, T.c, l.a);
    scalar ACP  = det(T.a, T.c, l.a);
    scalar ABP  = det(T.a, T.b, l.a);
    scalar ABC  = det(T.a, T.b, T.c);
    scalar ABCP = BCP - ACP + ABP - ABC;
    if (not same_sign (ABCP, ABCv)) return false; // outside start
    if (fabs(ABCP) > fabs(ABCv))    return false; // outside end
    double iABCv = 1.0 / static_cast<double>(ABCv);
    t = iABCv * ABCP;
    a = iABCv * PBCv; // Barycentric coordinates of intersection point.
    b = iABCv * APCv;
    c = 1.0 - a - b;
    return true;
  }

  /// Intersection between line segment and triangle.
  /// Outputs also barycentric coordinates of hit point.
  template<typename scalar>
  bool intersect (const ray<scalar,3>&      r,
		  const triangle<scalar,3>& T,
		  double& t,
		  double& a,
		  double& b,
		  double& c)
  {
    // vector_n<scalar,3> v = r.d;
    vector_n<scalar,3> v = -r.d;
    scalar ABv = det(T.a, T.b, v);
    scalar BCv = det(T.b, T.c, v);
    scalar ACv = det(T.a, T.c, v);
    scalar ABCv = BCv - ACv + ABv;
    if (fabs(ABCv) < std::numeric_limits<scalar>::epsilon())
      {
	// TODO: handle coplanar case by projection
	return false;
      }
    scalar PBv  = det(r.p, T.b, v);
    scalar PCv  = det(r.p, T.c, v);
    scalar PBCv = BCv - PCv + PBv;
    if (not same_sign (PBCv, ABCv)) return false; // a negative
    scalar APv  = det(T.a, r.p, v);
    scalar APCv = PCv - ACv + APv;
    if (not same_sign (APCv, ABCv)) return false; // b negative
    if (not same_sign (ABCv - PBCv - APCv, ABCv)) return false; // c negative
    // This part is probably too slow
    scalar BCP  = det(T.b, T.c, r.p);
    scalar ACP  = det(T.a, T.c, r.p);
    scalar ABP  = det(T.a, T.b, r.p);
    scalar ABC  = det(T.a, T.b, T.c);
    scalar ABCP = BCP - ACP + ABP - ABC;
    // if (fabs(ABCP) > fabs(ABCv))    return false; // outside end
    if (not same_sign(ABCP, ABCv)) return false; // outside start
    double iABCv = 1.0 / static_cast<double>(ABCv);
    t = iABCv * ABCP;
    a = iABCv * PBCv; // Barycentric coordinates of intersection point.
    b = iABCv * APCv;
    c = 1.0 - a - b;
    return true;
  }

  /// \brief Intersection test between solid sphere and segment.
  ///
  /// Finds which part of the segment is closest to the sphere centre
  /// (start point, end point or an interior point) and compares the
  /// squared distance from there to the centre against the squared
  /// radius. Everything is kept in squared / scaled form, so no division
  /// or square root is needed.
  ///
  /// \param l  Segment running from l.a to l.b.
  /// \param S  Sphere with centre S.c and radius S.r.
  /// \return   true if some point of the segment lies within the sphere.
  template<typename scalar, unsigned n>
  bool intersect (const line_segment<scalar,n>& l,
                  const sphere<scalar,n>&       S)
  {
    // proj / len_sq would be the parameter of the centre's projection on
    // the segment's line; the ratio is never actually formed.
    const scalar proj   = dot(S.c - l.a, l.b - l.a);
    const scalar len_sq = norm2 (l.b - l.a);

    // Projection falls before the start: the start point is closest.
    if (!same_sign(proj, len_sq))
      return norm2 (S.c - l.a) <= S.r * S.r;

    // Projection falls inside the segment: compare the perpendicular
    // distance, with both sides scaled by len_sq squared.
    if (proj < len_sq)
      return norm2 (proj*(l.b-l.a) - len_sq*(S.c - l.a)) <= len_sq*len_sq*S.r*S.r;

    // Projection falls at or past the end: the end point is closest.
    return norm2 (S.c - l.b) <= S.r * S.r;
  }

  /// Intersection test between sphere and ray.
  template<typename scalar, unsigned n>
  bool intersect (const ray<scalar,n>&          r,
		  const sphere<scalar,n>&       S)
  {
    vector_n<scalar,n> a = S.c - r.p;
    scalar ad  = dot(a, r.d);
    scalar d2  = norm2(r.d);
    scalar a2  = norm2(a);
    scalar ad2 = ad*ad;
    scalar det = ad2 - d2*(a2 - S.r*S.r);
    if (det >= 0)
      {
	if (ad >= 0)
	  return (det + ad2) >= 0;
	else
	  return (det - ad2) >= 0;
      }
    return false;
  }
  
  /// \brief Find nearest point to sphere on segment.
  ///
  /// Intersects the segment's line with the sphere surface. If an
  /// intersection lies on the segment, it is returned; with two such
  /// intersections the one nearer to the beginning of the segment wins.
  /// Otherwise the parameter of the segment point closest to the sphere
  /// surface is returned, always clamped to the segment.
  ///
  /// \param l  Segment running from l.a (t = 0) to l.b (t = 1).
  /// \param S  Sphere with centre S.c and radius S.r.
  /// \param t  Receives the segment parameter, always within [0, 1].
  /// \return   true if \p t is an intersection of the segment with the
  ///           sphere surface, false if it is only the nearest point.
  template<typename scalar, unsigned n>
  bool nearest_point_on (const line_segment<scalar,n>& l,
                         const sphere<scalar,n>&       S,
                         double&                       t)
  {
    vector_n<scalar, n> d  = l.b - l.a;
    vector_n<scalar, n> PC = l.a - S.c;
    double d_PC = dot(d, PC);
    // Discriminant of nd*s^2 + 2*d_PC*s + (|PC|^2 - r^2) = 0.
    double det = d_PC * d_PC - norm2(d) * (norm2(PC) - S.r*S.r);
    double nd = norm2(d);
    if (det <= 0.0)
      {
        // The line misses (or merely grazes) the sphere: take the
        // projection of the centre onto the line, clamped to the segment.
        // A zero-length segment (nd == 0) has only one point, so avoid
        // the 0/0 division and report its start.
        const double s = -d_PC;
        if (s <= 0.0 || nd <= 0.0)
          t = 0.0;
        else if (s >= nd)
          t = 1.0;
        else
          t = s / nd;
        return false;
      }
    // Roots scaled by nd: the real parameters are t0/nd and t1/nd, t0 < t1.
    double sdet = std::sqrt(det);
    double t0 = -d_PC - sdet;
    if (t0 >= 0.0 && t0 <= nd) { t = t0 / nd; return true; }
    double t1 = -d_PC + sdet;
    if (t1 >= 0.0 && t1 <= nd) { t = t1 / nd; return true; }
    if (t0 > nd)
      t = 1.0; // sphere lies entirely beyond the end point
    else if (t1 < 0.0)
      t = 0.0; // sphere lies entirely before the start point
    else
      {
        // t0 < 0 and t1 > nd: the whole segment is inside the sphere.
        // The nearest point to the surface is then an end point: the
        // start is -t0 away from the entry point, the end is t1 - nd
        // away from the exit point (same scaled units).
        t = (-t0 < t1 - nd) ? 0.0 : 1.0;
      }
    return false;
  }

  /// \brief Find nearest point to sphere on ray.
  ///
  /// Intersects the ray's line with the sphere surface. If an
  /// intersection lies on the ray (t >= 0), it is returned; with two such
  /// intersections the one nearer to the beginning of the ray wins.
  /// Otherwise the parameter of the ray point closest to the sphere is
  /// returned. The result is never negative, and it has no upper bound
  /// because a ray has no end point.
  ///
  /// \param r  Ray with origin r.p (t = 0) and direction r.d.
  /// \param S  Sphere with centre S.c and radius S.r.
  /// \param t  Receives the ray parameter, always >= 0.
  /// \return   true if \p t is an intersection of the ray with the sphere
  ///           surface, false if it is only the nearest point.
  template<typename scalar, unsigned n>
  bool nearest_point_on (const ray<scalar,n>&          r,
                         const sphere<scalar,n>&       S,
                         double&                       t)
  {
    vector_n<scalar, n> d  = r.d;
    vector_n<scalar, n> PC = r.p - S.c;
    double d_PC = dot(d, PC);
    // Discriminant of nd*s^2 + 2*d_PC*s + (|PC|^2 - r^2) = 0.
    double det = d_PC * d_PC - norm2(d) * (norm2(PC) - S.r*S.r);
    double nd = norm2(d);
    if (det <= 0.0)
      {
        // The line misses (or merely grazes) the sphere: take the
        // projection of the centre onto the line, clamped to the ray
        // origin only. Unlike a segment there is no far end, so the
        // projection is not limited to t <= 1. A zero direction
        // (nd == 0) leaves just the origin and must not divide 0 by 0.
        t = (-d_PC > 0.0 && nd > 0.0) ? -d_PC / nd : 0.0;
        return false;
      }
    // Roots scaled by nd: the real parameters are t0/nd and t1/nd, t0 < t1.
    double sdet = std::sqrt(det);
    double t0 = -d_PC - sdet;
    if (t0 >= 0.0) { t = t0 / nd; return true; }
    double t1 = -d_PC + sdet;
    if (t1 >= 0.0) { t = t1 / nd; return true; }
    // Both roots are negative: the sphere is entirely behind the origin,
    // so the origin itself is the nearest point on the ray.
    t = 0.0;
    return false;
  }


  /// \brief Intersection test between two solid spheres.
  ///
  /// Compares the squared distance between the centres with the squared
  /// sum of the radii, so no square root is taken.
  ///
  /// \param S  First sphere (centre S.c, radius S.r).
  /// \param R  Second sphere (centre R.c, radius R.r).
  /// \return   true if the spheres touch or overlap.
  template<typename scalar, unsigned n>
  bool intersect (const sphere<scalar,n>& S,
                  const sphere<scalar,n>& R)
  {
    const scalar reach = S.r + R.r; // largest centre distance still touching
    return distance2(S.c, R.c) <= reach * reach;
  }

  /// \brief Check if the closed intervals [a,b] and [c,d] overlap.
  ///
  /// The intervals share a point exactly when the larger of the two lower
  /// bounds does not exceed the smaller of the two upper bounds. Touching
  /// end points count as overlap.
  ///
  /// \param a  Lower bound of the first interval.
  /// \param b  Upper bound of the first interval.
  /// \param c  Lower bound of the second interval.
  /// \param d  Upper bound of the second interval.
  /// \return   true if the intervals have at least one point in common.
  template<typename scalar>
  bool overlap (const scalar& a,
                const scalar& b,
                const scalar& c,
                const scalar& d)
  {
    const scalar& highest_low = std::max(a, c);
    const scalar& lowest_high = std::min(b, d);
    return highest_low <= lowest_high;
  }
	
  /// \brief Project triangle on an axis, find interval.
  ///
  /// Used in separating axis tests. Each corner is projected with
  /// axis * corner (presumably the dot product provided by vector_n), and
  /// the smallest and largest of the three values are reported.
  ///
  /// \param axis  Axis to project onto.
  /// \param T     Triangle with corners T.a, T.b and T.c.
  /// \param a     Receives the lower end of the projected interval.
  /// \param b     Receives the upper end of the projected interval.
  template<typename scalar>
  void project (const vector_n<scalar,3>& axis,
                const triangle<scalar,3>& T,
                scalar& a,
                scalar& b)
  {
    // Start with the degenerate interval covering the first corner only.
    const scalar on_a = axis * T.a;
    a = on_a;
    b = on_a;

    // Widen it to include the second corner...
    const scalar on_b = axis * T.b;
    a = std::min (a, on_b);
    b = std::max (b, on_b);

    // ...and then the third.
    const scalar on_c = axis * T.c;
    a = std::min (a, on_c);
    b = std::max (b, on_c);
  }

  /// Check if an axis separates two objects by projection.
  template<typename object, typename scalar>
  bool separates (const vector_n<scalar,3>& a,
		  const object& T,
		  const object& S)
  {
    // check that axis is sane
    if (norm_max(a) <= 1e-10) return false;
    scalar Ta, Tb, Sa, Sb;
    project (a, T, Ta, Tb);
    project (a, S, Sa, Sb);
    return !overlap (Ta, Tb, Sa, Sb);
  }
  
  /// \brief Separating axis intersection test between two triangles.
  ///
  /// Checks if normals and cross products between edges separate the triangles.
  /// If not, then they must intersect.
  template<typename scalar>
  bool intersect (const triangle<scalar,3>& T,
		  const triangle<scalar,3>& S)
  {
    vector_n<scalar,3> Tn = normal(T);
    vector_n<scalar,3> Sn = normal(S);
    if (norm_max(cross(Tn,Sn)) > 1e-10) // check coplanarity
      {
	vector_n<scalar,3> T0 = T.b - T.a;
	vector_n<scalar,3> T1 = T.c - T.b;
	vector_n<scalar,3> T2 = T.a - T.c;
	vector_n<scalar,3> S0 = S.b - S.a;
	vector_n<scalar,3> S1 = S.c - S.b;
	vector_n<scalar,3> S2 = S.a - S.c;
	return !( separates (Tn, T, S) 
		  || separates (Sn, T, S) 
		  || separates (cross (T0, S0), T, S) 
		  || separates (cross (T0, S1), T, S) 
		  || separates (cross (T0, S2), T, S) 
		  || separates (cross (T1, S0), T, S) 
		  || separates (cross (T1, S1), T, S) 
		  || separates (cross (T1, S2), T, S) 
		  || separates (cross (T2, S0), T, S) 
		  || separates (cross (T2, S1), T, S) 
		  || separates (cross (T2, S2), T, S) );
      }
    else
      return false; // coplanar case => epic fail.
  }


  /// \brief Separating axis intersection test between two oriented boxes.
  ///
  /// Tries the three axes of each box and the nine cross products between
  /// an axis of \p T and an axis of \p S as separating axes. If none of
  /// them separates the boxes, they must intersect.
  ///
  /// \param T  First box.
  /// \param S  Second box.
  /// \return   true if no candidate axis separates the boxes.
  template<typename scalar>
  bool intersect (const oriented_box<scalar,3>& T,
                  const oriented_box<scalar,3>& S)
  {
    // Axes of the first box.
    for (int i = 0; i < 3; ++i)
      if (separates (T.axis(i), T, S))
        return false;

    // Axes of the second box.
    for (int i = 0; i < 3; ++i)
      if (separates (S.axis(i), T, S))
        return false;

    // Cross products of one axis from each box.
    for (int i = 0; i < 3; ++i)
      for (int j = 0; j < 3; ++j)
        if (separates (cross(T.axis(i), S.axis(j)), T, S))
          return false;

    return true;
  }

}

#endif
